#! /usr/bin/env python3

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os.path as osp
import openslide
from pathlib import Path
from skimage.filters import threshold_otsu
import glob
#before importing HDFStore, make sure 'tables' is installed by pip3 install tables
from pandas import HDFStore
from openslide.deepzoom import DeepZoomGenerator
from sklearn.model_selection import StratifiedShuffleSplit
#import tensorflow as tf

#config = tf.ConfigProto()
#config.gpu_options.allow_growth = True
#session = tf.Session(config=config)

print('Hi, patch extraction can take a while, please be patient...')

# Directory searched for whole-slide images.
# (Alternative local checkout: /home/wli/Downloads/CAMELYON16/training/tumor)
slide_path = '/raida/wjc/CAMELYON16/training/tumor'

# Directory searched for the mask images that accompany the slides.
BASE_TRUTH_DIR = '/raida/wjc/CAMELYON16/training/masking'

# Alphabetically sorted lists of every .tif file directly inside each directory.
slide_paths = sorted(glob.glob(osp.join(slide_path, '*.tif')))
BASE_TRUTH_DIRS = sorted(glob.glob(osp.join(BASE_TRUTH_DIR, '*.tif')))
#image_pair = zip(tumor_paths, anno_tumor_paths)  
#image_pair = list(image_mask_pair)

# Edge length, in full-resolution pixels, of the square patch that one
# thumbnail pixel stands for. The (row, col) thumbnail coordinates stored in
# `tile_loc` are later used by gen_imgs as DeepZoom tile addresses on a
# tile_size=256 grid, so the thumbnails must be sampled on the same 256-pixel
# grid. (A 224-pixel grid yields addresses that point at the wrong patches and
# can fall outside the tile grid near the right/bottom edges.)
PATCH_SIZE = 256

# Number of normal (non-tumor) patches discarded when balancing the classes.
# It is capped by the number of normal patches that actually exist.
NUM_NORMAL_TO_DROP = 1800000

base_truth_dir = Path(BASE_TRUTH_DIR)

# Build one DataFrame per slide and concatenate once at the end.
# DataFrame.append was removed in pandas 2.0 and copied every accumulated row
# on each call.
per_slide_samples = []
for current_slide_path in slide_paths:
    slide_contains_tumor = osp.basename(current_slide_path).startswith('tumor_')

    with openslide.open_slide(current_slide_path) as slide:
        # One thumbnail pixel per PATCH_SIZE x PATCH_SIZE patch.
        thumbnail = slide.get_thumbnail((slide.dimensions[0] / PATCH_SIZE,
                                         slide.dimensions[1] / PATCH_SIZE))

        thumbnail_grey = np.array(thumbnail.convert('L'))  # convert to grayscale
        thresh = threshold_otsu(thumbnail_grey)
        binary = thumbnail_grey > thresh

        # Flatten the 2-D grid into one row per (row, col) position. Pixels at
        # or below the Otsu threshold (the darker ones) are treated as tissue.
        patches = pd.DataFrame(pd.DataFrame(binary).stack())
        patches['is_tissue'] = ~patches[0]
        patches.drop(0, axis=1, inplace=True)
        patches['slide_path'] = current_slide_path

    if slide_contains_tumor:
        # The mask for "<name>.tif" is expected at "<name>_mask.tif".
        truth_slide_path = base_truth_dir / osp.basename(current_slide_path).replace('.tif', '_mask.tif')
        with openslide.open_slide(str(truth_slide_path)) as truth:
            thumbnail_truth = truth.get_thumbnail((truth.dimensions[0] / PATCH_SIZE,
                                                   truth.dimensions[1] / PATCH_SIZE))

        # Any non-zero mask pixel marks the patch as tumor.
        patches_y = pd.DataFrame(pd.DataFrame(np.array(thumbnail_truth.convert("L"))).stack())
        patches_y['is_tumor'] = patches_y[0] > 0
        patches_y.drop(0, axis=1, inplace=True)

        # Align slide and mask patches on their shared (row, col) index.
        samples = pd.concat([patches, patches_y], axis=1)
    else:
        samples = patches
        samples['is_tumor'] = False

    # Remove patches with no tissue. The explicit copy makes the column
    # assignment below act on an independent frame, not a slice of `samples`.
    samples = samples[samples.is_tissue == True].copy()
    # Keep the (row, col) grid position before the index is discarded.
    samples['tile_loc'] = list(samples.index)
    samples.reset_index(inplace=True, drop=True)

    per_slide_samples.append(samples)

if not per_slide_samples:
    raise FileNotFoundError('No .tif slides found in %s' % slide_path)

sampletotal = pd.concat(per_slide_samples, ignore_index=True)

# Randomly drop normal patches to reduce the imbalance against tumor patches.
# `== False` is deliberate: it skips rows whose is_tumor is missing (NaN).
idx = sampletotal.index[sampletotal['is_tumor'] == False].tolist()
# Never ask for more patches than exist; sampling without replacement would
# otherwise raise ValueError.
num_to_drop = min(NUM_NORMAL_TO_DROP, len(idx))
drop_indices = np.random.choice(idx, num_to_drop, replace=False) if num_to_drop else []
sampletotal_subset = sampletotal.drop(drop_indices)
# Re-number the rows 0..n-1. This is important: the train/validation split
# further down selects rows by these numbers.
sampletotal_subset.reset_index(drop=True, inplace=True)

# Optional fixed-size subsample (disabled):
#NUM_SAMPLES = 50000
#sampletotal_subset = sampletotal_subset.sample(NUM_SAMPLES, random_state=42)
#sampletotal_subset.reset_index(drop=True, inplace=True)

print(sampletotal_subset.is_tumor.value_counts())

# Generator function that produces the real image patches. It contains (currently commented-out) np.save calls that can store the arrays in NumPy's platform-independent .npy format.

import cv2
from keras.utils.np_utils import to_categorical

NUM_CLASSES = 2 # not_tumor, tumor

def gen_imgs(samples, batch_size, base_truth_dir=BASE_TRUTH_DIR, shuffle=True):
    """Endlessly yield batches of image tiles and their one-hot tumor masks.

    Args:
        samples: DataFrame with one row per patch. Each row needs a
            `slide_path` column (path of the slide file) and a `tile_loc`
            column (the patch's (row, col) grid position; it is reversed
            before being used as the DeepZoom tile address).
        batch_size: Maximum number of patches per yielded batch. The last
            batch of a pass may be smaller.
        base_truth_dir: Directory containing the "<slide name>_mask.tif"
            mask files for slides whose file name starts with "tumor_".
        shuffle: If true, the rows are reshuffled at the start of every pass.

    Yields:
        A tuple (X_train, y_train). X_train is the array of tile images of
        the batch; y_train is the one-hot mask array reshaped to
        (batch, 256, 256, NUM_CLASSES).

    Raises:
        ValueError: On the first iteration, if `samples` is empty or
            `batch_size` is smaller than 1. Without this check the endless
            loop would spin forever without ever yielding.
    """
    num_samples = len(samples)
    if num_samples == 0:
        raise ValueError('gen_imgs needs at least one sample')
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1')

    while True:  # Loop forever so the generator never terminates
        if shuffle:
            samples = samples.sample(frac=1)  # shuffle samples

        for offset in range(0, num_samples, batch_size):
            batch_samples = samples.iloc[offset:offset + batch_size]

            images = []
            masks = []
            for _, batch_sample in batch_samples.iterrows():
                slide_contains_tumor = osp.basename(batch_sample.slide_path).startswith('tumor_')
                # tile_loc is reversed to form the tile address.
                tile_address = batch_sample.tile_loc[::-1]

                # NOTE(review): tiles on the right/bottom slide border may be
                # smaller than 256 px, which would make the stacked batch
                # ragged -- confirm whether such patches can be sampled.
                with openslide.open_slide(batch_sample.slide_path) as slide:
                    tiles = DeepZoomGenerator(slide, tile_size=256, overlap=0, limit_bounds=False)
                    img = tiles.get_tile(tiles.level_count - 1, tile_address)

                # only load truth mask for tumor slides
                if slide_contains_tumor:
                    truth_slide_path = osp.join(base_truth_dir, osp.basename(batch_sample.slide_path).replace('.tif', '_mask.tif'))
                    with openslide.open_slide(str(truth_slide_path)) as truth:
                        truth_tiles = DeepZoomGenerator(truth, tile_size=256, overlap=0, limit_bounds=False)
                        mask = truth_tiles.get_tile(truth_tiles.level_count - 1, tile_address)
                        # Any non-zero grey value counts as tumor (class 1).
                        mask = (cv2.cvtColor(np.array(mask), cv2.COLOR_RGB2GRAY) > 0).astype(int)
                else:
                    # Slides without tumor get an all-background mask. It has
                    # the same integer dtype as the masks built above.
                    mask = np.zeros((256, 256), dtype=int)

                images.append(np.array(img))
                masks.append(mask)

            X_train = np.array(images)
            y_train = np.array(masks)
            # One-hot encode every mask pixel and restore the image layout.
            y_train = to_categorical(y_train, num_classes=NUM_CLASSES).reshape(y_train.shape[0], 256, 256, NUM_CLASSES)

            #np.save('X_train.npy', X_train)
            #np.save('y_train.npy', y_train)

            yield X_train, y_train

#split patches into 20% validation dataset and 80% training dataset
from sklearn.model_selection import StratifiedShuffleSplit

# Split the balanced samples into a training set (80%) and a validation set
# (20%). StratifiedShuffleSplit keeps the proportion of tumor patches the same
# in both sets.
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
# n_splits=1, so the splitter yields exactly one (train, test) pair.
train_index, test_index = next(split.split(sampletotal_subset, sampletotal_subset["is_tumor"]))
# split() returns positional row numbers, so rows are selected with .iloc.
# Label-based .loc only gave the same rows because the index had been reset
# to 0..n-1 beforehand.
train_samples_subset = sampletotal_subset.iloc[train_index]
validation_samples_subset = sampletotal_subset.iloc[test_index]


#split whole set of samples to train and validation sets

#from sklearn.model_selection import StratifiedShuffleSplit

# split samples into train and validation set
# use StratifiedShuffleSplit to ensure both sets have same proportions of tumor patches
##split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
##for train_index, test_index in split.split(sampletotal, sampletotal["is_tumor"]):
##        train_samples = sampletotal.loc[train_index]
##        validation_samples = sampletotal.loc[test_index]


from keras.models import Sequential
from keras.layers import Lambda, Dropout
from keras.layers.convolutional import Convolution2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.models import model_from_json
from keras.callbacks import ModelCheckpoint

# Fully convolutional network: 256x256 RGB patch in, per-pixel two-class
# softmax map out.
model = Sequential([
    # Scale pixel values from [0, 255] to [-0.5, 0.5].
    Lambda(lambda x: x / 255.0 - 0.5, input_shape=(256, 256, 3)),
    # Encoder: two strided 5x5 convolutions, each followed by max pooling.
    Convolution2D(100, (5, 5), strides=(2, 2), activation='elu', padding='same'),
    MaxPooling2D(),
    Convolution2D(200, (5, 5), strides=(2, 2), activation='elu', padding='same'),
    MaxPooling2D(),
    # Two 3x3 convolutions at the reduced resolution.
    Convolution2D(300, (3, 3), activation='elu', padding='same'),
    Convolution2D(300, (3, 3), activation='elu', padding='same'),
    Dropout(0.1),
    # 1x1 convolution that reduces the features to one score map per class.
    Convolution2D(2, (1, 1)),
    # Transposed convolution with stride 16 that upsamples the score maps and
    # applies softmax over the two class channels.
    Conv2DTranspose(2, (31, 31), strides=(16, 16), activation='softmax', padding='same'),
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

BATCH_SIZE = 32
N_EPOCHS = 100

from datetime import datetime


train_generator = gen_imgs(train_samples_subset, BATCH_SIZE)
validation_generator = gen_imgs(validation_samples_subset, BATCH_SIZE)

# Checkpoint: keep only the weights of the epoch with the best validation
# accuracy.
# NOTE(review): the metric is logged as 'val_acc' only by older Keras
# releases; confirm the name matches the installed version.
filepath = "weight_best.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]
print("OK, training starts...")

# Number of batches needed to see every sample once, as plain integers.
steps_per_epoch = int(np.ceil(len(train_samples_subset) / BATCH_SIZE))
validation_steps = int(np.ceil(len(validation_samples_subset) / BATCH_SIZE))

# Train model
train_start_time = datetime.now()
model.fit_generator(train_generator, steps_per_epoch,
    validation_data=validation_generator,
    validation_steps=validation_steps,
    epochs=N_EPOCHS,
    # Without this argument the checkpoint above is never invoked and no
    # weights are written during training.
    callbacks=callbacks_list)

train_end_time = datetime.now()
# total_seconds() covers the whole duration; .seconds alone wraps around
# after 24 hours.
print("Model training time: %.1f minutes" % ((train_end_time - train_start_time).total_seconds() / 60,))

#model.save('model0719.h5')

# Store the network architecture as JSON in model.json. Only the layer
# structure is written here; the weight export below is disabled.
model_json = model.to_json()
Path("model.json").write_text(model_json)
#model.save_weights("model.h5")
print("Saving model to disk")

## load json and create model
#json_file = open('model.json', 'r')
#loaded_model_json = json_file.read()
#json_file.close()
#loaded_model = model_from_json(loaded_model_json)
## load weights into new model
#loaded_model.load_weights("model.h5")
#print("Loaded model from disk")


 
## evaluate loaded model on test data
#loaded_model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
#score = loaded_model.evaluate(X, Y, verbose=0)
#print("%s: %.2f%%" % (loaded_model.metrics_names[1], score[1]*100))
